import subprocess

import pytest

try:
    from cuda.bindings import driver as cuda
    from cuda.bindings import nvrtc
except ImportError:
    from cuda import cuda, nvrtc


def ASSERT_DRV(err):
    """Raise ``RuntimeError`` unless *err* is a CUDA-driver or NVRTC success code.

    Args:
        err: a ``cuda.CUresult`` or ``nvrtc.nvrtcResult`` status value.

    Raises:
        RuntimeError: if *err* is a failure code of either type, or if it is
            of neither type.
    """
    # (status type, its success value, message template used on failure)
    known_result_kinds = (
        (cuda.CUresult, cuda.CUresult.CUDA_SUCCESS, 'Cuda Error: {}'),
        (nvrtc.nvrtcResult, nvrtc.nvrtcResult.NVRTC_SUCCESS, 'Nvrtc Error: {}'),
    )
    for result_type, success_value, failure_template in known_result_kinds:
        if not isinstance(err, result_type):
            continue
        if err != success_value:
            raise RuntimeError(failure_template.format(err))
        return
    raise RuntimeError('Unknown error type: {}'.format(err))


def getSMVersion():
    """Return the compute capability of CUDA device 0 as ``major * 10 + minor``.

    For example, compute capability 9.0 yields ``90``. Every driver call is
    checked with ``ASSERT_DRV``, so a driver failure raises ``RuntimeError``.
    """
    (status,) = cuda.cuInit(0)
    ASSERT_DRV(status)

    status, device = cuda.cuDeviceGet(0)
    ASSERT_DRV(status)

    # Query major first, then minor, exactly as the driver exposes them.
    attributes = cuda.CUdevice_attribute
    capability = []
    for attribute in (
            attributes.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
            attributes.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
    ):
        status, value = cuda.cuDeviceGetAttribute(attribute, device)
        ASSERT_DRV(status)
        capability.append(value)

    major, minor = capability
    return major * 10 + minor


# The default test cases for flash attention fmha that will be used in TRTLLM.
@pytest.mark.parametrize('d', [32, 40, 64, 72, 80, 96, 104, 128, 160, 192, 256],
                         ids=[
                             "head-size-32", "head-size-40", "head-size-64",
                             "head-size-72", "head-size-80", "head-size-96",
                             "head-size-104", "head-size-128", "head-size-160",
                             "head-size-192", "head-size-256"
                         ])
@pytest.mark.parametrize('s', [1024], ids=["seqlen-1024"])
@pytest.mark.parametrize('dtype', ["-fp16", "-bf16", "-fp16-fp32", "-e4m3"],
                         ids=["fp16", "bf16", "fp16-fp32", "e4m3"])
@pytest.mark.parametrize('flag', [
    "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv",
    "-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks"
])
@pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"])
def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel):
    """Run the default TRT-LLM flash-attention fmha configurations.

    For each supported combination, ``bin/fmha.exe`` is invoked once per mask
    variant: dense, causal, and causal + GQA; custom mask + GQA for plain
    contiguous/paged layouts; causal + ALiBi when compatible; and causal MQA
    with a sliding window. ``check=True`` turns any non-zero exit status of
    the binary into a test failure. Combinations that the current
    architecture / dtype / head size / flags do not support are skipped.
    """
    verbose = 0
    sm_version = getSMVersion()
    is_fp8 = dtype == '-e4m3'

    # ---- Skip unsupported combinations (same order as before). ----
    if flag == "-use-attention-sinks" and sm_version != 90:
        pytest.skip("use-attention-sinks is only supported on sm90 currently.")
    if sm_version == 90 and tiled_kernel == "-force-non-tiled":
        pytest.skip(
            "Tiled/non-tiled flags only make a difference to ampere-style kernels."
        )
    if sm_version == 70 and dtype != "-fp16":
        pytest.skip("Volta fmha only supports fp16 data type.")
    # Looks like cublas doesn't support non-multiple-of-16 head sizes.
    if is_fp8 and d in [40, 72, 104]:
        pytest.skip("cublas doesn't support non-multiple-of-16 head sizes.")
    # Only Ada (sm89) and Hopper (sm90) support fp8 fmha currently.
    if is_fp8 and sm_version not in [89, 90]:
        pytest.skip("only ada/hopper support fp8 fmha currently.")
    # Ada fp8 fmha only supports non-tiled kernels currently.
    if is_fp8 and sm_version == 89 and tiled_kernel == "":
        pytest.skip("ada fp8 fmha only supports non-tiled kernels currently.")
    # Only d = 128 kernels are generated with softcapping-scale-bmm1 support.
    if d != 128 and '-softcapping-scale-bmm1' in flag:
        pytest.skip(
            "Only d = 128 + softcapping-scale-bmm1 kernels are generated by default."
        )
    # Known sm89 e4m3 kernel issues: a bug with -s-q < 128, and softcapping
    # is not exercised on this path.
    if sm_version == 89 and is_fp8:
        if "-s-q 63" in flag:
            pytest.skip("skipping chunk size 63 for sm89 e4m3 fmha.")
        if "softcapping-scale-bmm1" in flag:
            pytest.skip("skipping softcapping-scale-bmm1 for sm89 e4m3 fmha.")

    # ---- Derive the per-case arguments. ----
    # Higher error tolerance for bf16, fp16 + softcapping and e4m3.
    if dtype == '-bf16' or (dtype == '-fp16'
                            and '-softcapping-scale-bmm1' in flag):
        epsilon = ' -epsilon 0.03'
    elif is_fp8:
        epsilon = ' -epsilon 0.2'
    else:
        epsilon = ' -epsilon 0.02'

    # Known accuracy issue for the dense (unmasked) run in this case.
    skip_dense_mask_test = (d == 64 and dtype in ['-fp16-fp32', '-bf16']
                            and tiled_kernel == "")

    # Force non-tiled kernels for d = 64 + contiguous-q-kv below sm90.
    if d == 64 and flag == '-contiguous-q-kv' and sm_version < 90:
        flag += ' -force-non-tiled'

    # ---- Build the list of mask variants, in execution order. ----
    mask_variants = [] if skip_dense_mask_test else [""]
    mask_variants += ["-causal-mask", "-causal-mask -gqa 2"]
    # NOTE(review): ``flag`` may have been extended with ' -force-non-tiled'
    # above; in that case this exact-match test is False and the custom-mask
    # run is skipped for d = 64 + contiguous-q-kv on sm < 90. Kept as-is;
    # confirm this is intended.
    if flag in ('-contiguous-q-kv', '-paged-kv'):
        mask_variants.append("-custom-mask -gqa 2")
    # ALiBi doesn't work with softcapping-scale-bmm1 / use-attention-sinks.
    if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag:
        mask_variants.append("-causal-mask -alibi")
    mask_variants.append(
        "-causal-mask -multi-query-attention -sliding-window-size 54")

    base_cmd = f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128"
    common_args = f"-v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}"
    for mask in mask_variants:
        # filter(None, ...) drops the empty dense-mask entry so the command
        # string matches the hand-written form exactly.
        cmd = " ".join(filter(None, (base_cmd, mask, common_args)))
        subprocess.run(cmd, shell=True, check=True)


# The test cases for sage attention.
@pytest.mark.parametrize('d', [80, 128], ids=["head-size-80", "head-size-128"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
def test_trtllm_sage_attention_fmha(d, s):
    """Run the sage-attention fmha test (bf16, sage blocks q/k/v = 64/32/32, non-tiled).

    Only an Ada (sm89) command is defined here. Every other architecture is
    skipped rather than reported as a pass with nothing executed.
    """
    sm_version = getSMVersion()
    if sm_version != 89 and sm_version != 90:
        pytest.skip("Sage attention only supports sm89 and sm90 currently.")
    if sm_version == 90:
        # No sm90 command is defined below; skip instead of passing without
        # running anything.
        pytest.skip("Sage attention fmha tests only run on sm89 currently.")

    # Ada.
    subprocess.run(
        f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 16 -h 8 -d {d} -bf16 \
            -sage-block-q 64 -sage-block-k 32 -sage-block-v 32 -force-non-tiled",
        shell=True,
        check=True)


# The test cases for mla attention.
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
def test_trtllm_context_mla_attention_fmha(dtype, s):
    """Run context-phase MLA fmha tests (d = 192, dv = 128, separate-q-k-v layout).

    The causal-mask run is always issued. For ``-bf16`` and ``-e4m3``, the
    ``-save-softmax`` path (needed for chunked prefill) is also run with a
    padding mask and a causal mask.
    """
    sm_version = getSMVersion()
    if sm_version < 90:
        pytest.skip("MLA kernels are only tested on sm90 and above currently.")
    if dtype in ("-e4m3", "-e4m3 -bf16-output") and sm_version not in (90, 120):
        pytest.skip("FP8 MLAs are only supported on sm90 and sm120 currently.")

    # Higher error tolerance for bf16 at s = 4096.
    epsilon = ' -epsilon 0.03' if (dtype == "-bf16" and s == 4096) else ''

    prefix = (f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 "
              f"-d 192 -dv 128 {dtype} ")
    # Context-phase kernels always use the separate-q-k-v layout.
    commands = [prefix + f"-causal-mask {epsilon} -separate-q-k-v"]
    # Chunked prefill needs -save-softmax (bf16 / e4m3, separate-q-k-v layout):
    # first with a padding mask, then with a causal mask.
    if dtype in ("-bf16", "-e4m3"):
        commands.append(prefix + f"{epsilon} -separate-q-k-v -save-softmax")
        commands.append(
            prefix + f"-causal-mask {epsilon} -separate-q-k-v -save-softmax")

    for command in commands:
        subprocess.run(command, shell=True, check=True)


@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
                         ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
@pytest.mark.parametrize('num_grouped_heads', [16, 32, 64, 128],
                         ids=[
                             "num-grouped-heads-16", "num-grouped-heads-32",
                             "num-grouped-heads-64", "num-grouped-heads-128"
                         ])
def test_trtllm_gen_mla_attention_fmha(dtype, s, num_grouped_heads):
    """Run generation-phase MLA fmha tests (d = 576, dv = 512, paged KV, s-q 128).

    Runs without warp specialization, for the given number of grouped heads.
    FP8 variants are only exercised on sm120.
    """
    sm_version = getSMVersion()
    if sm_version < 90:
        pytest.skip("MLA kernels are only tested on sm90 and above currently.")
    wants_fp8 = dtype in ("-e4m3", "-e4m3 -bf16-output")
    if wants_fp8 and sm_version != 120:
        pytest.skip("FP8 MLAs are only supported on sm120 currently.")

    # Higher error tolerance for bf16 at s = 4096.
    epsilon = ''
    if dtype == "-bf16" and s == 4096:
        epsilon = ' -epsilon 0.03'

    # Generation phase kernels.
    command = f"bin/fmha.exe -v 0 -runs 1 -s-q 128 -min-s 1024 -s {s} -b 8 -h 1 -d 576 -dv 512 {dtype} \
        -paged-kv -num-grouped-heads {num_grouped_heads} -force-non-warp-specialization {epsilon}"
    subprocess.run(command, shell=True, check=True)


# The test cases for saving softmax.
@pytest.mark.parametrize('mask', ["-causal-mask", ""],
                         ids=["causal-mask", "padding-mask"])
@pytest.mark.parametrize(
    's', [128, 256, 384, 512],
    ids=["seqlen-128", "seqlen-256", "seqlen-384", "seqlen-512"])
def test_trtllm_save_softmax(mask, s):
    """Run fp16 fmha (d = 64) with ``-save-softmax`` on a contiguous-q-kv layout."""
    # The run of spaces before the mask is inert: the shell collapses it.
    command = ("bin/fmha.exe -v 0 -runs 1 "
               f"-s {s} -d 64 -min-s 1 -b 1 -h 4 -fp16 "
               f"    {mask} -contiguous-q-kv -save-softmax")
    subprocess.run(command, shell=True, check=True)


# The test cases for chunked attention.
@pytest.mark.parametrize('chunked_attention_size', [128, 256, 512, 1024],
                         ids=[
                             "chunked-attention-size-128",
                             "chunked-attention-size-256",
                             "chunked-attention-size-512",
                             "chunked-attention-size-1024"
                         ])
@pytest.mark.parametrize('input_layout', ["", "-paged-kv"],
                         ids=["packed-qkv", "paged-kv"])
def test_trtllm_chunked_attention(chunked_attention_size, input_layout):
    """Run chunked-attention fmha tests (fp16, d = 128) on Hopper (sm90) only.

    Always runs the full-sequence case with *input_layout*. For the paged-KV
    layout, it additionally runs a chunked-context case (``-s-q 256``).
    """
    if getSMVersion() != 90:
        pytest.skip("Chunked attention only supported on hopper currently.")

    commands = [
        f"bin/fmha.exe -d 128 -b 4 -h 5 -fp16 -s 8192 -min-s 4096 \
        -chunked-attention-size {chunked_attention_size} {input_layout} "
    ]
    # Chunked context works together with chunked attention.
    if input_layout == "-paged-kv":
        commands.append(
            f"bin/fmha.exe -d 128 -b 8 -h 5 -s-q 256 -s 8192 -min-s 4096 -fp16 \
            -chunked-attention-size {chunked_attention_size} -paged-kv")

    for command in commands:
        subprocess.run(command, shell=True, check=True)
